{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Road Following - Live demo (Jetson-category TensorRT) with Jetbot anti-collision trained TRT model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "JetRacer-style category script that runs inference on the JetBot: choose a category (e.g. follow the road or an object) and drive the JetBot together with a JetBot model trained for collision avoidance."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load the TensorRT-optimized models by executing the cell below.\n",
    "## You need to choose the categories and the pre-trained model first"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch2trt import TRTModule\n",
    "\n",
    "device = torch.device('cuda')\n",
    "\n",
    "# alternative category set:\n",
    "#CATEGORIES = ['apex','bottle']\n",
    "\n",
    "CATEGORIES = ['road','bottle']\n",
    "\n",
    "# for a plain JetBot road-following model trained without categories,\n",
    "# use an empty list and load that model below instead:\n",
    "#CATEGORIES = []\n",
    "#model_trt.load_state_dict(torch.load('best_steering_model_xy_trt.pth'))  # JetBot model converted to TRT, no categories\n",
    "\n",
    "model_trt = TRTModule()\n",
    "model_trt.load_state_dict(torch.load('road_following_model_trt.pth'))  # JetRacer-style model trained for JetBot and converted to TRT\n",
    "\n",
    "model_trt_collision = TRTModule()\n",
    "model_trt_collision.load_state_dict(torch.load('best_model_trt.pth'))  # collision-avoidance model: one object counts as blocked; road markings (ground, stripes) count as free\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Creating the Pre-Processing Function"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We have now loaded our models, but there's a slight issue: the format we trained the model on doesn't exactly match the format of the camera. To bridge that gap, we need to do some preprocessing. This involves the following steps:\n",
    "\n",
    "1. Convert from HWC layout to CHW layout\n",
    "2. Normalize using the same parameters as we did during training (our camera provides values in the [0, 255] range and training loaded images in the [0, 1] range, so we need to scale by 255.0)\n",
    "3. Transfer the data from CPU memory to GPU memory\n",
    "4. Add a batch dimension"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchvision.transforms as transforms\n",
    "import torch.nn.functional as F\n",
    "import cv2\n",
    "import PIL.Image\n",
    "import numpy as np\n",
    "\n",
    "# per-channel normalization parameters, kept on the GPU in half precision\n",
    "mean = torch.Tensor([0.485, 0.456, 0.406]).cuda().half()\n",
    "std = torch.Tensor([0.229, 0.224, 0.225]).cuda().half()\n",
    "\n",
    "def preprocess(image):\n",
    "    # HWC uint8 array -> PIL image -> CHW float tensor scaled to [0, 1], moved to the GPU\n",
    "    image = PIL.Image.fromarray(image)\n",
    "    image = transforms.functional.to_tensor(image).to(device).half()\n",
    "    # normalize in place with the per-channel mean and std\n",
    "    image.sub_(mean[:, None, None]).div_(std[:, None, None])\n",
    "    # add the batch dimension\n",
    "    return image[None, ...]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from jetbot import Camera, bgr8_to_jpeg\n",
    "\n",
    "camera = Camera()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ipywidgets\n",
    "import ipywidgets.widgets as widgets\n",
    "import time\n",
    "import IPython\n",
    "from IPython.display import display\n",
    "import traitlets\n",
    "\n",
    "target_widget = widgets.Image(format='jpeg', width=224, height=224)\n",
    "x_slider = ipywidgets.FloatSlider(min=-1.0, max=1.0, description='x')\n",
    "y_slider = ipywidgets.FloatSlider(min=0, max=1.0, orientation='horizontal', description='y')\n",
    "\n",
    "def display_xy(camera_image):\n",
    "    image = np.copy(camera_image)\n",
    "    x = x_slider.value\n",
    "    y = y_slider.value\n",
    "    x = int(x * 224 / 2 + 112)\n",
    "    y = int(y * -224 / 2 + 112)\n",
    "    image = cv2.circle(image, (x, y), 8, (0, 255, 0), 3)\n",
    "    image = cv2.circle(image, (112, 224), 8, (0, 0,255), 3)\n",
    "    image = cv2.line(image, (x,y), (112,224), (255,0,0), 3)\n",
    "    jpeg_image = bgr8_to_jpeg(image)\n",
    "    return jpeg_image\n",
    "\n",
    "#time.sleep(1)\n",
    "traitlets.dlink((camera, 'value'), (target_widget, 'value'), transform=display_xy)\n",
    "\n",
    "\n",
    "display(widgets.HBox([target_widget]))\n",
    "d2 = IPython.display.display(\"\", display_id=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll also create our robot instance which we'll need to drive the motors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from jetbot import Robot\n",
    "\n",
    "robot = Robot()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "speed_gain_slider = ipywidgets.FloatSlider(min=0.0, max=1.0, step=0.01, description='speed gain')\n",
    "steering_gain_slider = ipywidgets.FloatSlider(min=0.0, max=1.0, step=0.01, value=0.09, description='steering gain')\n",
    "steering_dgain_slider = ipywidgets.FloatSlider(min=0.0, max=0.5, step=0.001, value=0.24, description='steering kd')\n",
    "steering_bias_slider = ipywidgets.FloatSlider(min=-0.3, max=0.3, step=0.01, value=0.0, description='steering bias')\n",
    "\n",
    "display(speed_gain_slider, steering_gain_slider, steering_dgain_slider, steering_bias_slider)\n",
    "\n",
    "blocked_slider = ipywidgets.FloatSlider(description='blocked', min=0.0, max=120.0, orientation='horizontal')\n",
    "stopduration_slider = ipywidgets.IntSlider(min=1, max=1000, step=1, value=200, description='Manu. time stop')  # anti-collision stop duration in frames\n",
    "block_threshold = ipywidgets.FloatSlider(min=0, max=2, step=0.1, value=0.9, description='Manu. bl threshold')  # blocked probability above which the robot stops\n",
    "\n",
    "display(ipywidgets.HBox([blocked_slider, stopduration_slider, block_threshold]))\n",
    "\n",
    "# create a New View for Output\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, let's display some sliders that will let us see what JetBot is thinking.  The x and y sliders will display the predicted x, y values.\n",
    "\n",
    "The steering slider will display our estimated steering value.  Please remember, this value isn't the actual angle of the target, but simply a value that is\n",
    "nearly proportional.  When the actual angle is ``0``, this will be zero, and it will increase / decrease with the actual angle.  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "steering_slider = ipywidgets.FloatSlider(min=-1.0, max=1.0, description='steering')\n",
    "speed_slider = ipywidgets.FloatSlider(min=0, max=1.0, orientation='horizontal', description='speed')\n",
    "category_widget = ipywidgets.Dropdown(options=np.array(CATEGORIES), description='category')\n",
    "\n",
    "\n",
    "display(y_slider, x_slider, speed_slider, steering_slider)\n",
    "\n",
    "#choose category for road or object following\n",
    "category_execution_widget = ipywidgets.VBox([category_widget])\n",
    "\n",
    "display(category_execution_widget)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def start_category(change):\n",
    "    # update the index of the selected category whenever the dropdown changes\n",
    "    global category_index\n",
    "    if not CATEGORIES:\n",
    "        category_index = 0\n",
    "    else:\n",
    "        category_index = CATEGORIES.index(category_widget.value)\n",
    "\n",
    "category_widget.observe(start_category, names='value')\n",
    "\n",
    "# run the same logic once so category_index is initialized before the first frame\n",
    "if not CATEGORIES:\n",
    "    print(\"List is empty.\")\n",
    "    category_index = 0\n",
    "else:\n",
    "    category_index = CATEGORIES.index(category_widget.value)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we'll create a function that will get called whenever the camera's value changes. This function will do the following steps:\n",
    "\n",
    "1. Pre-process the camera image\n",
    "2. Execute the collision-avoidance network and, if the path is free, the road-following network\n",
    "3. Compute the approximate steering value\n",
    "4. Control the motors using proportional / derivative control (PD)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torchvision import models\n",
    "from torch.hub import load_state_dict_from_url\n",
    "from PIL import Image\n",
    "import cv2\n",
    "import numpy as np\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "\n",
    "count = 0\n",
    "count_stops = 0\n",
    "stop_time = 50  # number of frames the bot pauses after an obstacle is detected (overwritten by stopduration_slider on every frame)\n",
    "angle = 0.0\n",
    "angle_last = 0.0\n",
    "go_on = 1  # 1: driving, 2: pausing after an obstacle was detected\n",
    "max_x = camera.width\n",
    "min_x = 0\n",
    "max_y = camera.height\n",
    "min_y = 0\n",
    "sum_bottle = 0.0\n",
    "x = 0\n",
    "y = 0\n",
    "\n",
    "def execute(change):\n",
    "    global angle, angle_last, category_index, count, count_stops, stop_time, go_on, x, y\n",
    "    count += 1\n",
    "    t1 = time.time()\n",
    "\n",
    "    image = change['new']\n",
    "    image_preproc = preprocess(image)\n",
    "\n",
    "    # collision-avoidance model\n",
    "    y_collision = model_trt_collision(image_preproc)\n",
    "\n",
    "    # apply `softmax` to normalize the output vector so it sums to 1 (which makes it a probability distribution)\n",
    "    y_collision = F.softmax(y_collision, dim=1)\n",
    "    prob_blocked = float(y_collision.flatten()[0])\n",
    "    blocked_slider.value = prob_blocked\n",
    "    stop_time = stopduration_slider.value\n",
    "\n",
    "    if go_on == 1:\n",
    "        if prob_blocked > block_threshold.value:\n",
    "            # the collision-avoidance model has recognized an obstacle (not a specific one):\n",
    "            # pause for stop_time frames (about 0.05 s per frame)\n",
    "            count_stops += 1\n",
    "            x = max_x / 2.0  # target at the horizontal center, so the steering angle is zero\n",
    "            y = 0.0  # target at the top edge (straight ahead)\n",
    "            speed_slider.value = 0.0  # set speed to zero (could also reverse or turn here)\n",
    "            go_on = 2\n",
    "        else:\n",
    "            # no obstacle: run road following\n",
    "            go_on = 1\n",
    "            count_stops = 0\n",
    "            xy = model_trt(image_preproc).detach().float().cpu().numpy().flatten()\n",
    "            # pick the (x, y) pair of the selected category and map it from [-1, 1] to pixel coordinates\n",
    "            x = float(xy[2 * category_index])\n",
    "            y = float(xy[2 * category_index + 1])\n",
    "            x = int(max_x * (x / 2.0 + 0.5))\n",
    "            y = int(max_y * (y / 2.0 + 0.5))\n",
    "            speed_slider.value = speed_gain_slider.value\n",
    "    else:\n",
    "        count_stops += 1\n",
    "        if count_stops < stop_time:  # number of frames the bot should pause\n",
    "            x = max_x / 2.0  # keep the steering angle at zero\n",
    "            y = 0.0\n",
    "            speed_slider.value = 0.0  # keep the speed at zero\n",
    "        else:\n",
    "            go_on = 1\n",
    "            count_stops = 0\n",
    "\n",
    "    # convert pixel coordinates to joystick-like values (x centered on 0, y measured from the bottom edge)\n",
    "    x_joysticklike = ((x - max_x / 2.0) - min_x) / (max_x - min_x)\n",
    "    y_joysticklike = ((max_y - y) - min_y) / (max_y - min_y)\n",
    "    x_slider.value = x_joysticklike\n",
    "    y_slider.value = y_joysticklike\n",
    "\n",
    "    # PD control on the steering angle\n",
    "    angle = np.arctan2(x_joysticklike, y_joysticklike)\n",
    "    pid = angle * steering_gain_slider.value + (angle - angle_last) * steering_dgain_slider.value\n",
    "    angle_last = angle\n",
    "    steering_slider.value = pid + steering_bias_slider.value\n",
    "    robot.left_motor.value = max(min(speed_slider.value + steering_slider.value, 1.0), 0.0)\n",
    "    robot.right_motor.value = max(min(speed_slider.value - steering_slider.value, 1.0), 0.0)\n",
    "\n",
    "    # show the processing rate for this frame (guard against a zero time difference)\n",
    "    t2 = time.time()\n",
    "    s = f\"{int(1 / max(t2 - t1, 1e-6))} FPS\"\n",
    "    d2.update(IPython.display.HTML(s))\n",
    "\n",
    "\n",
    "execute({'new': camera.value})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The jupyter plot shows the x and y values: the x value varies more, while the y value stays more or less the same (y might be useful later for velocity regulation).\n",
    "The initialization of jupyterplot takes more than 30 seconds, so you need to wait. Once `0 FPS` appears, activate the system with `camera.observe` below.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Cool! We've created our neural network execution function, but now we need to attach it to the camera for processing.\n",
    "\n",
    "We accomplish that with the observe function."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    ">WARNING: This code will move the robot!! Please make sure your robot has clearance and is on the Lego or track surface you collected data on. The road follower should work, but the neural network is only as good as the data it's trained on!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "camera.observe(execute, names='value')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Awesome! If your robot is plugged in, it should now be generating new commands with each new camera frame.\n",
    "\n",
    "You can now place JetBot on the Lego or track surface you collected data on and see whether it can follow the track.\n",
    "\n",
    "If you want to stop this behavior, you can detach this callback by executing the code below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "camera.unobserve(execute, names='value')\n",
    "time.sleep(0.1)  # add a small sleep to make sure frames have finished processing\n",
    "robot.stop()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Conclusion\n",
    "That's it for this live demo! Hopefully you had some fun seeing your JetBot moving smoothly on the track, following the road!!!\n",
    "\n",
    "If your JetBot wasn't following the road very well, try to spot where it fails. The beauty is that we can collect more data for these failure scenarios and the JetBot should get even better :)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
